-
Notifications
You must be signed in to change notification settings - Fork 565
[autoparallel] Add experimental config to enable autoparallel_asynctp #1772
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
IvanKobzarev
wants to merge
26
commits into
main
Choose a base branch
from
IvanKobzarev/stack/2
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
TODO - try converting model params into fake tensors - figure out init fn - integrate torchtitan configs for DP/TP to control autop Hack an init_fn for llama3 and observe loss decreasing with autoparallel """ [rank0]:[titan] 2025-06-16 16:24:16,593 - root - INFO - Training starts at step 1. [rank0]:[titan] 2025-06-16 16:24:23,544 - root - INFO - step: 1 loss: 8.1880 memory: 4.88GiB(6.16%) tps: 28 [rank0]:[titan] 2025-06-16 16:24:23,545 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-06-16 16:24:23,842 - root - INFO - step: 2 loss: 8.1610 memory: 4.90GiB(6.20%) tps: 13,785 [rank0]:[titan] 2025-06-16 16:24:24,135 - root - INFO - step: 3 loss: 8.0871 memory: 4.90GiB(6.20%) tps: 14,006 [rank0]:[titan] 2025-06-16 16:24:24,433 - root - INFO - step: 4 loss: 7.9516 memory: 4.90GiB(6.20%) tps: 13,770 [rank0]:[titan] 2025-06-16 16:24:24,727 - root - INFO - step: 5 loss: 7.8552 memory: 4.90GiB(6.20%) tps: 13,959 [rank0]:[titan] 2025-06-16 16:24:25,023 - root - INFO - step: 6 loss: 7.7732 memory: 4.90GiB(6.20%) tps: 13,859 [rank0]:[titan] 2025-06-16 16:24:25,324 - root - INFO - step: 7 loss: 7.6987 memory: 4.90GiB(6.20%) tps: 13,664 [rank0]:[titan] 2025-06-16 16:24:25,617 - root - INFO - step: 8 loss: 7.6779 memory: 4.90GiB(6.20%) tps: 13,985 [rank0]:[titan] 2025-06-16 16:24:25,911 - root - INFO - step: 9 loss: 7.6043 memory: 4.90GiB(6.20%) tps: 13,962 [rank0]:[titan] 2025-06-16 16:24:26,207 - root - INFO - step: 10 loss: 7.5778 memory: 4.90GiB(6.20%) tps: 13,891 """ Adopt new autoparallel API with meta-init model Allows reverting a lot of the hacks in the original integration that were caused by not creating a model obj in the train.py due to passing a model_fn builder to autop. Fixes to align with latest autoparallel Add inductor config knobs for comms optimizations to torchtitan Make inductor always run compile passes basically, this is an annoying workaround for debugging iteratively. 1- you run the model, it compiles, but something weird happens 2- you enable some logging or tlparse, rerun. but inductor decides not to run your pass anymore, its results are cached. since (2) has confused me horribly on more than one occasion, i just disable caching for now Drop hacky llama3_init_fn and use autop init_weights feature Relying on meta-pytorch/autoparallel#20, this lets us automatically apply a user's init_weights fn to the autoparallel model. Verified this works with `CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --training.dataset c4` ``` [rank0]:[titan] 2025-07-02 16:18:02,007 - root - INFO - Training starts at step 1. [rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - step: 1 loss: 8.1848 memory: 1.09GiB(1.14%) tps: 77 tflops: 0.01 mfu: 0.00% [rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-07-02 16:18:08,310 - root - INFO - step: 2 loss: 8.1619 memory: 1.15GiB(1.21%) tps: 48,138 tflops: 3.46 mfu: 0.35 % [rank0]:[titan] 2025-07-02 16:18:08,356 - root - INFO - step: 3 loss: 8.1140 memory: 1.15GiB(1.21%) tps: 88,440 tflops: 6.36 mfu: 0.64 % [rank0]:[titan] 2025-07-02 16:18:08,406 - root - INFO - step: 4 loss: 8.0099 memory: 1.15GiB(1.21%) tps: 82,626 tflops: 5.94 mfu: 0.60 % [rank0]:[titan] 2025-07-02 16:18:08,457 - root - INFO - step: 5 loss: 7.8928 memory: 1.15GiB(1.21%) tps: 81,594 tflops: 5.87 mfu: 0.59 % [rank0]:[titan] 2025-07-02 16:18:08,508 - root - INFO - step: 6 loss: 7.7758 memory: 1.15GiB(1.21%) tps: 79,607 tflops: 5.72 mfu: 0.58 % [rank0]:[titan] 2025-07-02 16:18:08,559 - root - INFO - step: 7 loss: 7.6221 memory: 1.15GiB(1.21%) tps: 81,448 tflops: 5.86 mfu: 0.59 % [rank0]:[titan] 2025-07-02 16:18:08,611 - root - INFO - step: 8 loss: 7.5578 memory: 1.15GiB(1.21%) tps: 79,732 tflops: 5.73 mfu: 0.58 % [rank0]:[titan] 2025-07-02 16:18:08,659 - root - INFO - step: 9 loss: 7.3851 memory: 1.15GiB(1.21%) tps: 85,655 tflops: 6.16 mfu: 0.62 % [rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - step: 10 loss: 7.3361 memory: 1.15GiB(1.21%) tps: 81,855 tflops: 5.89 mfu: 0.60 % [rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - Sleeping 2 seconds for other ranks to complete ``` fix lint
lets existing torchtitan knobs which govern DP/TP mesh creation and mesh size influence the sharding constraints of autoparallel, allowing it to support these different sharding configurations.
Signed-off-by: Edward Z. Yang <[email protected]>
Signed-off-by: Edward Z. Yang <[email protected]>
- fix passing of "none" (not None) to control bucketing passes
prints an (internal, vpn) only link for each profile trace file that's saved to manifold. Just search for 'trace' in your job logs on mast, and click one of the rank links. e.g. [trainer37|5]:[titan] 2025-08-07 14:21:01,227 - root - INFO - Finished dumping profiler traces in 5.22 seconds: https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/torchtrain_datasets/tree/outputs/torchtitan-64-whc-jv2j4mp/profile_trace/iteration_20/rank37_trace.json
This PR makes bucket sizes for all-gather and reduce-scatter to be of the same size for 1d FSDP.
IMO we should just add the loss in the model and let autoparallel parallelize it for us. But for now, let's follow how the other models are implemented
just add `--experimental.enable_simplefsdp_passes` and do not try to combine it with other `bucket_*` or `reorder_*` options.
This command should now run `CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name deepseekv3_auto_parallel` However it doesn't actually do anything with autoparallel yet. Next step is to attach local_map to the model so that autoparallel can run.
Validated debugmodel llama3 works, but ds3 crashes becuase of `build_optimizers_with_moe_load_balancing` doing stuff that traverses the original model structure, only now its an AutoParallelModule which isn't compatible, we'll have to disable this optimization for now and think about what to do. Note: paths have changed, update your run commands: `CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4` Failing (ds3): `CONFIG_FILE=././torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name deepseekv3_auto_parallel`
as titled, this pr adds entry to simplefsdp's autobucketing pass in autoparallel. original code is in: pytorch/pytorch#160282 The main code for autobucketing pass will be added to autoparallel repo.
a16e72e
to
3ab4bf9
Compare
IvanKobzarev
added a commit
that referenced
this pull request
Sep 30, 2025
stack-info: PR: #1772, branch: IvanKobzarev/stack/2
IvanKobzarev
added a commit
that referenced
this pull request
Oct 6, 2025
stack-info: PR: #1772, branch: IvanKobzarev/stack/2
3ab4bf9
to
9082c19
Compare
stack-info: PR: #1772, branch: IvanKobzarev/stack/2
9082c19
to
45b15f6
Compare
IvanKobzarev
added a commit
to IvanKobzarev/torchtitan
that referenced
this pull request
Oct 8, 2025
stack-info: PR: pytorch#1772, branch: IvanKobzarev/stack/2
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Stacked PRs:
[autoparallel] Add experimental config to enable autoparallel_asynctp